import builtins
from typing import Any, Generic, List, Callable, Union, Tuple

import numpy as np

import ray
from ray.types import ObjectRef
from ray.data.block import Block, BlockAccessor, \
    BlockMetadata, T
from ray.data.impl.arrow_block import ArrowRow
from ray.util.annotations import PublicAPI
from ray.data.impl.util import _check_pyarrow_version

# Type alias for the value returned by a single datasource write task.
# It is declared as ``Any``, so no particular result shape is enforced here.
WriteResult = Any


@PublicAPI(stability="beta")
class Datasource(Generic[T]):
    """Base interface for custom ``ray.data.Dataset`` datasources.

    Reading goes through ``ray.data.read_datasource()``; writing to a
    datasource that supports it goes through ``Dataset.write_datasource()``.

    ``RangeDatasource`` and ``DummyOutputDatasource`` are example
    implementations of a readable and a writable datasource respectively.
    """

    def prepare_read(self, parallelism: int,
                     **read_args) -> List["ReadTask[T]"]:
        """Build the read tasks that together load this datasource.

        Args:
            parallelism: Requested read parallelism. Implementations should
                return a number of read tasks as close to this as possible.
            read_args: Extra keyword arguments forwarded to the datasource
                implementation.

        Returns:
            Read tasks which can be run in parallel, each one reading
            blocks from the datasource.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def do_write(self, blocks: List[ObjectRef[Block]],
                 metadata: List[BlockMetadata],
                 **write_args) -> List[ObjectRef[WriteResult]]:
        """Start the Ray tasks that write blocks to the datasource.

        Args:
            blocks: References to the data blocks to write. Generating one
                write task per block is the recommended approach.
            metadata: Block metadata entries.
            write_args: Extra keyword arguments forwarded to the datasource
                implementation.

        Returns:
            The outputs of the launched write tasks, as a list.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def on_write_complete(self, write_results: List[WriteResult],
                          **kwargs) -> None:
        """Hook invoked once a write job has completed.

        Implementations may use it to "commit" the written output.
        ``write_datasource()`` does not return to the user until this method
        has succeeded; should it fail, ``on_write_failed()`` is called.

        The base implementation does nothing.

        Args:
            write_results: Results of the write tasks.
            kwargs: Placeholder kept for forward compatibility.
        """

    def on_write_failed(self, write_results: List[ObjectRef[WriteResult]],
                        error: Exception, **kwargs) -> None:
        """Hook invoked when a write job has failed.

        Invocation on write failures is best-effort only. The base
        implementation does nothing.

        Args:
            write_results: Futures for the write task results.
            error: The first error that was encountered.
            kwargs: Placeholder kept for forward compatibility.
        """


@PublicAPI(stability="beta")
class ReadTask(Callable[[], Block]):
    """Zero-argument callable that reads one block of a dataset.

    Instances are produced by ``datasource.prepare_read()``. Calling one
    returns a ``ray.data.Block``; the metadata describing the read is
    available beforehand through ``get_metadata()``.

    Ray runs read tasks inside remote functions so reads execute in parallel.
    """

    def __init__(self, read_fn: Callable[[], Block], metadata: BlockMetadata):
        # Keep the reader and its metadata side by side; the reader is only
        # invoked when the task itself is called.
        self._read_fn, self._metadata = read_fn, metadata

    def get_metadata(self) -> BlockMetadata:
        """Return the metadata supplied when this task was created."""
        return self._metadata

    def __call__(self) -> Block:
        """Run the wrapped read function and return the block it produces."""
        block = self._read_fn()
        return block


class RangeDatasource(Datasource[Union[ArrowRow, int]]):
    """An example datasource that generates ranges of numbers from [0..n).

    Examples:
        >>> source = RangeDatasource()
        >>> ray.data.read_datasource(source, n=10).take()
        ... [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    """

    def prepare_read(self,
                     parallelism: int,
                     n: int,
                     block_format: str = "list",
                     tensor_shape: Tuple = (1, )) -> List[ReadTask]:
        """Create read tasks that together generate the values [0..n).

        Args:
            parallelism: Requested read parallelism; each task produces about
                ``n // parallelism`` rows (at least one).
            n: Total number of rows to generate.
            block_format: One of ``"list"`` (plain Python list of ints),
                ``"arrow"`` (Arrow table with an int column ``"value"``), or
                ``"tensor"`` (Arrow table whose ``"value"`` column is a
                ``TensorArray``; each row is a tensor of ``tensor_shape``
                filled with the row's index).
            tensor_shape: Per-row tensor shape, used only by ``"tensor"``.

        Returns:
            The read tasks, in row order.

        Raises:
            ValueError: If ``block_format`` is not a supported format. This
                is checked up front, even when no rows are requested.
        """
        # The schema depends only on the arguments, never on the block, so
        # resolve it (and validate ``block_format``) once instead of once
        # per read task.
        if block_format == "arrow":
            _check_pyarrow_version()
            import pyarrow
            schema = pyarrow.Table.from_pydict({"value": [0]}).schema
        elif block_format == "tensor":
            _check_pyarrow_version()
            from ray.data.extensions import TensorArray
            import pyarrow
            # A small sample table is enough to derive the tensor schema.
            tensor = TensorArray(
                np.ones(tensor_shape, dtype=np.int64) * np.expand_dims(
                    np.arange(0, 10), tuple(
                        range(1, 1 + len(tensor_shape)))))
            schema = pyarrow.Table.from_pydict({"value": tensor}).schema
        elif block_format == "list":
            schema = int
        else:
            raise ValueError("Unsupported block type", block_format)

        # Number of int64 values stored per row: one for "list"/"arrow",
        # the product of the tensor dimensions for "tensor".
        if block_format == "tensor":
            element_size = int(np.prod(tensor_shape))
        else:
            element_size = 1

        read_tasks: List[ReadTask] = []
        block_size = max(1, n // parallelism)

        # Example of a read task. In a real datasource, this would pull data
        # from an external system instead of generating dummy data.
        def make_block(start: int, count: int) -> Block:
            if block_format == "arrow":
                return pyarrow.Table.from_arrays(
                    [np.arange(start, start + count)], names=["value"])
            elif block_format == "tensor":
                # Broadcast each row index across a tensor of ``tensor_shape``.
                tensor = TensorArray(
                    np.ones(tensor_shape, dtype=np.int64) * np.expand_dims(
                        np.arange(start, start + count),
                        tuple(range(1, 1 + len(tensor_shape)))))
                return pyarrow.Table.from_pydict({"value": tensor})
            else:
                return list(builtins.range(start, start + count))

        i = 0
        while i < n:
            # The final block may be shorter than ``block_size``.
            count = min(block_size, n - i)
            read_tasks.append(
                ReadTask(
                    # Bind ``i``/``count`` as defaults so each task keeps its
                    # own range rather than the loop's final values.
                    lambda i=i, count=count: make_block(i, count),
                    BlockMetadata(
                        num_rows=count,
                        # 8 bytes per int64 value.
                        size_bytes=8 * count * element_size,
                        schema=schema,
                        input_files=None)))
            i += block_size

        return read_tasks


class DummyOutputDatasource(Datasource[Union[ArrowRow, int]]):
    """An example implementation of a writable datasource for testing.

    Examples:
        >>> output = DummyOutputDatasource()
        >>> ray.data.range(10).write_datasource(output)
        >>> assert output.num_ok == 1
    """

    def __init__(self):
        # Setup a dummy actor to send the data. In a real datasource, write
        # tasks would send data to an external system instead of a Ray actor.
        @ray.remote
        class DataSink:
            def __init__(self):
                self.rows_written = 0
                self.enabled = True

            def write(self, block: Block) -> str:
                """Count the rows of ``block``; raise if the sink is disabled."""
                block = BlockAccessor.for_block(block)
                if not self.enabled:
                    raise ValueError("disabled")
                self.rows_written += block.num_rows()
                return "ok"

            def get_rows_written(self):
                return self.rows_written

            def set_enabled(self, enabled):
                self.enabled = enabled

        self.data_sink = DataSink.remote()
        # Number of write jobs that completed / failed, respectively.
        self.num_ok = 0
        self.num_failed = 0

    def do_write(self, blocks: List[ObjectRef[Block]],
                 metadata: List[BlockMetadata],
                 **write_args) -> List[ObjectRef[WriteResult]]:
        """Send every block to the sink actor, one write call per block.

        Args:
            blocks: References to the blocks to write.
            metadata: Block metadata (unused by this dummy datasource).
            write_args: Extra keyword arguments (unused).

        Returns:
            One result reference per block, in the order of ``blocks``.
        """
        return [self.data_sink.write.remote(b) for b in blocks]

    def on_write_complete(self, write_results: List[WriteResult],
                          **kwargs) -> None:
        """Record a successful write job.

        Args:
            write_results: Results of the write tasks; each is expected to
                be ``"ok"``.
            kwargs: Accepted for compatibility with the base ``Datasource``
                signature; ignored.
        """
        assert all(w == "ok" for w in write_results), write_results
        self.num_ok += 1

    def on_write_failed(self, write_results: List[ObjectRef[WriteResult]],
                        error: Exception, **kwargs) -> None:
        """Record a failed write job.

        Args:
            write_results: Futures for the write task results (unused).
            error: The error that was encountered (unused).
            kwargs: Accepted for compatibility with the base ``Datasource``
                signature; ignored.
        """
        self.num_failed += 1


class RandomIntRowDatasource(Datasource[ArrowRow]):
    """Example datasource producing rows made of random int64 columns.

    Columns are named ``c_0``, ``c_1``, ... up to ``num_columns``.

    Examples:
        >>> source = RandomIntRowDatasource()
        >>> ray.data.read_datasource(source, n=10, num_columns=2).take()
        ... {'c_0': 1717767200176864416, 'c_1': 999657309586757214}
        ... {'c_0': 4983608804013926748, 'c_1': 1160140066899844087}
    """

    def prepare_read(self, parallelism: int, n: int,
                     num_columns: int) -> List[ReadTask]:
        """Create read tasks that together generate ``n`` random rows.

        Args:
            parallelism: Requested read parallelism; each task produces about
                ``n // parallelism`` rows (at least one).
            n: Total number of rows to generate.
            num_columns: Number of int64 columns in every row.

        Returns:
            The read tasks, each yielding an Arrow table when called.
        """
        _check_pyarrow_version()
        import pyarrow

        rows_per_block = max(1, n // parallelism)

        def make_block(count: int, num_columns: int) -> Block:
            # One random int64 array per column, ``count`` values each.
            upper = np.iinfo(np.int64).max
            columns = np.random.randint(
                upper, size=(num_columns, count), dtype=np.int64)
            column_names = [f"c_{i}" for i in range(num_columns)]
            return pyarrow.Table.from_arrays(columns, names=column_names)

        # A one-row table of zeros is enough to derive the schema shared by
        # every block.
        sample = {f"c_{i}": [0] for i in range(num_columns)}
        schema = pyarrow.Table.from_pydict(sample).schema

        read_tasks: List[ReadTask] = []
        offset = 0
        while offset < n:
            # The last block may hold fewer rows than the others.
            count = min(rows_per_block, n - offset)
            block_meta = BlockMetadata(
                num_rows=count,
                size_bytes=8 * count * num_columns,
                schema=schema,
                input_files=None)
            # Defaults pin this iteration's values into the task.
            reader = (lambda count=count, num_columns=num_columns:
                      make_block(count, num_columns))
            read_tasks.append(ReadTask(reader, block_meta))
            offset += rows_per_block

        return read_tasks
